package com.atguigu.state;

import com.atguigu.bean.WaterSensor;
import com.atguigu.function.WaterSensorMapFunction;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

import java.time.Duration;

public class KeyedAggregatingStateDemo {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);

        SingleOutputStreamOperator<WaterSensor> sensorDS = env
                .socketTextStream("10.75.186.206", 9999)
                .map(new WaterSensorMapFunction())
                .assignTimestampsAndWatermarks(
                        WatermarkStrategy.<WaterSensor>forBoundedOutOfOrderness(Duration.ofSeconds(3))
                                .withTimestampAssigner((w, ts) -> w.getTs() * 1000)
                );

        //  计算每种传感器的水位平均值
        sensorDS.keyBy(WaterSensor::getId)
                .process(new KeyedProcessFunction<String, WaterSensor, String>() {
                    AggregatingState<Integer, Double> aggregatingState;
                    @Override
                    public void open(Configuration parameters) throws Exception {
                        super.open(parameters);
                        aggregatingState = getRuntimeContext().getAggregatingState(
                                new AggregatingStateDescriptor<Integer, Tuple2<Integer, Integer>, Double>(
                                        "",
                                        new AggregateFunction<Integer, Tuple2<Integer, Integer>, Double>() {
                                            // 初始化
                                            @Override
                                            public Tuple2<Integer, Integer> createAccumulator() {
                                                return Tuple2.of(0, 0);
                                            }

                                            // 计算
                                            @Override
                                            public Tuple2<Integer, Integer> add(Integer vc, Tuple2<Integer, Integer> acc) {
                                                return Tuple2.of(vc + acc.f0, acc.f1 + 1);
                                            }

                                            // 最终结果
                                            @Override
                                            public Double getResult(Tuple2<Integer, Integer> acc) {
                                                return acc.f0 * 1D / acc.f1;
                                            }

                                            @Override
                                            public Tuple2<Integer, Integer> merge(Tuple2<Integer, Integer> integerIntegerTuple2, Tuple2<Integer, Integer> acc1) {
                                                return null;
                                            }
                                        },
                                        Types.TUPLE(Types.INT, Types.INT)
                                )
                        );
                    }

                    @Override
                    public void processElement(WaterSensor value, KeyedProcessFunction<String, WaterSensor, String>.Context ctx, Collector<String> out) throws Exception {
                        aggregatingState.add(value.getVc());
                        out.collect(value.getId() + "的平均水位是：" + aggregatingState.get());
                    }
                })
                .print();
        env.execute();
    }

}
